/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "execute_node_builder.h"

#include <algorithm>

#include "execute_graph_builder.h"
namespace ge {
// Records the node id; BuildNode() later copies it into the op desc's id_ field.
// Returns *this so that setter calls can be chained.
ExecuteNodeBuilder &ExecuteNodeBuilder::SetId(int64_t id) {
  node_id_ = id;
  return *this;
}
// Adds the node name to the string pool and remembers the returned pool offset;
// BuildNode() later resolves that offset back to a string through the pool reader.
//
// name: C string with the node name. A null pointer is rejected here, in the same
//       way SetInputName()/SetOutputName() reject it: the builder is marked FAILED
//       and BuildNode() will report that error.
// Returns *this so that setter calls can be chained.
ExecuteNodeBuilder &ExecuteNodeBuilder::SetName(const char *name) {
  if (name == nullptr) {
    // log error
    error_code_ = FAILED;
    return *this;
  }
  node_name_ = pool_builder_->AddString(name);
  return *this;
}
// Records the pool offset of the op definition; BuildNode() later resolves it
// with the pool reader's GetOpDef().
// Returns *this so that setter calls can be chained.
ExecuteNodeBuilder &ExecuteNodeBuilder::SetOpDef(PoolOffset op_def) {
  op_def_ = op_def;
  return *this;
}
// Stores the name of the input at `index` as a string-pool offset.
//
// index: input slot; must not be negative.
// name:  C string with the input name; must not be null.
// On an invalid argument nothing is stored and the builder is marked FAILED,
// which BuildNode() reports later.
// Returns *this so that setter calls can be chained.
ExecuteNodeBuilder &ExecuteNodeBuilder::SetInputName(int index, const char *name) {
  if (index < 0 || name == nullptr) {
    // log error
    error_code_ = FAILED;
  } else {
    input_indexes_to_desc_[index].name = pool_builder_->AddString(name);
  }
  return *this;
}
// Stores the tensor-desc pool offset for the input at `index`.
//
// index: input slot; a negative value marks the builder FAILED and stores nothing.
// td:    pool offset that BuildNode() resolves with the pool reader's GetTensorDesc().
// Returns *this so that setter calls can be chained.
ExecuteNodeBuilder &ExecuteNodeBuilder::SetInputDesc(int index, PoolOffset td) {
  if (index >= 0) {
    input_indexes_to_desc_[index].tensor_desc = td;
  } else {
    // log error
    error_code_ = FAILED;
  }
  return *this;
}

// Stores the name of the output at `index` as a string-pool offset.
//
// index: output slot; must not be negative.
// name:  C string with the output name; must not be null.
// On an invalid argument nothing is stored and the builder is marked FAILED,
// which BuildNode() reports later.
// Returns *this so that setter calls can be chained.
ExecuteNodeBuilder &ExecuteNodeBuilder::SetOutputName(int index, const char *name) {
  if (index < 0 || name == nullptr) {
    // log error
    error_code_ = FAILED;
  } else {
    output_indexes_to_desc_[index].name = pool_builder_->AddString(name);
  }
  return *this;
}
// Stores the tensor-desc pool offset for the output at `index`.
//
// index: output slot; a negative value marks the builder FAILED and stores nothing.
// td:    pool offset that BuildNode() resolves with the pool reader's GetTensorDesc().
// Returns *this so that setter calls can be chained.
ExecuteNodeBuilder &ExecuteNodeBuilder::SetOutputDesc(int index, PoolOffset td) {
  if (index >= 0) {
    output_indexes_to_desc_[index].tensor_desc = td;
  } else {
    // log error
    error_code_ = FAILED;
  }
  return *this;
}

// Creates a builder that adds strings through `pool_builder` and remembers
// `node_offset`, which GetOffset() returns unchanged.
// NOTE(review): `pool_builder` is stored without a null check and is dereferenced by
// SetName()/SetInputName()/SetOutputName() — callers presumably must pass a valid pointer.
ExecuteNodeBuilder::ExecuteNodeBuilder(GraphPoolBuilder *pool_builder, PoolOffset node_offset)
    : pool_builder_(pool_builder), node_offset_(node_offset) {}

// Materializes the state collected by the setters into `node->op_desc_`.
//
// pool_reader: resolves the pool offsets recorded by the setters (names, op def, tensor descs).
// node:        destination; its op_desc_ is filled in place and may be partially written
//              when an error is returned.
//
// Returns the error recorded by an earlier setter if there is one; FAILED when an
// input/output/attr index is out of range or an attr whose e_type is kMustIo has no value;
// SUCCESS otherwise.
// NOTE(review): CHECK_NOT_NULL is a project macro; it presumably returns an error status
// when its argument is null — confirm against its definition.
Status ExecuteNodeBuilder::BuildNode(const GraphPoolReader *pool_reader, ExecuteNode *node) {
  // A setter already rejected an argument; report that instead of building.
  if (error_code_ != SUCCESS) {
    return error_code_;
  }
  // pool_reader is dereferenced below, so it needs the same guard as node.
  CHECK_NOT_NULL(pool_reader);
  CHECK_NOT_NULL(node);

  // build op_desc
  auto &op_desc = node->op_desc_;
  op_desc.id_ = node_id_;
  op_desc.name_ = pool_reader->GetString(node_name_);
  op_desc.op_def_ = pool_reader->GetOpDef(op_def_);
  CHECK_NOT_NULL(op_desc.name_);
  CHECK_NOT_NULL(op_desc.op_def_);

  // inputs_ is sized by the number of described inputs but indexed by the caller-supplied
  // slot, so a sparse set of slots (e.g. only index 2) would write past the end.
  op_desc.inputs_.resize(input_indexes_to_desc_.size());
  for (const auto &input_index_to_desc : input_indexes_to_desc_) {
    const auto index = input_index_to_desc.first;
    // The unsigned cast also turns a negative index into an out-of-range value.
    if (static_cast<size_t>(index) >= op_desc.inputs_.size()) {
      // log error, input indexes are not contiguous from 0
      return FAILED;
    }
    op_desc.inputs_[index].index = index;
    op_desc.inputs_[index].name = pool_reader->GetString(input_index_to_desc.second.name);
    op_desc.inputs_[index].td = pool_reader->GetTensorDesc(input_index_to_desc.second.tensor_desc);
    CHECK_NOT_NULL(op_desc.inputs_[index].name);
    CHECK_NOT_NULL(op_desc.inputs_[index].td);
  }

  // Same sizing/indexing rule as for the inputs.
  op_desc.outputs_.resize(output_indexes_to_desc_.size());
  for (const auto &output_index_to_desc : output_indexes_to_desc_) {
    const auto index = output_index_to_desc.first;
    if (static_cast<size_t>(index) >= op_desc.outputs_.size()) {
      // log error, output indexes are not contiguous from 0
      return FAILED;
    }
    op_desc.outputs_[index].index = index;
    op_desc.outputs_[index].name = pool_reader->GetString(output_index_to_desc.second.name);
    op_desc.outputs_[index].td = pool_reader->GetTensorDesc(output_index_to_desc.second.tensor_desc);
    CHECK_NOT_NULL(op_desc.outputs_[index].name);
    CHECK_NOT_NULL(op_desc.outputs_[index].td);
  }

  // Every attr value supplied through SetAttr() must refer to an attr declared by the op def.
  auto &attrs_def = op_desc.op_def_->attrs_def;
  for (const auto &attr_index_to_value : attr_indexes_to_value_) {
    const auto attr_index = attr_index_to_value.first;
    if (attr_index < 0 || static_cast<size_t>(attr_index) >= attrs_def.size()) {
      // log error, invalid attr index specified
      return FAILED;
    }
  }

  // Walk the declared attrs in order: register each name/id, then take the supplied value,
  // or fall back to the default creator when no value was supplied.
  for (size_t i = 0; i < attrs_def.size(); ++i) {
    op_desc.attr_store_.SetNameAndId(attrs_def[i].name, static_cast<AttrId>(i));
    auto iter = attr_indexes_to_value_.find(static_cast<int>(i));
    if (iter != attr_indexes_to_value_.end()) {
      // The stored value is handed over as an rvalue, so the builder's copy may be left
      // in a moved-from state after this call.
      op_desc.attr_store_.Set(static_cast<int>(i), std::move(iter->second));
      continue;
    }
    if (attrs_def[i].e_type == kMustIo) {
      // log error, an attr of type kMustIo has no value
      return FAILED;
    }
    if (attrs_def[i].default_creator != nullptr) {
      op_desc.attr_store_.Set(static_cast<int>(i), attrs_def[i].default_creator());
    }
    // Otherwise there is no default creator function; the attr is left without a value.
  }

  return SUCCESS;
}
// Returns the node offset that was passed to the constructor.
PoolOffset ExecuteNodeBuilder::GetOffset() const {
  return node_offset_;
}
// Returns the number of distinct input indexes that have received a name or a tensor desc
// so far (the element count of the index-to-desc container, not the largest index).
size_t ExecuteNodeBuilder::GetInputNum() const {
  return input_indexes_to_desc_.size();
}
// Returns the number of distinct output indexes that have received a name or a tensor desc
// so far (the element count of the index-to-desc container, not the largest index).
size_t ExecuteNodeBuilder::GetOutputNum() const {
  return output_indexes_to_desc_.size();
}
// Stores the value for the attr at `index`, replacing any value stored earlier for it.
//
// index: position of the attr in the op def's attr list. A negative index is rejected
//        here, like in the other index-taking setters: nothing is stored and the builder
//        is marked FAILED. The upper bound is only known once the op def is resolved,
//        so BuildNode() checks it.
// value: moved into the builder.
// Returns *this so that setter calls can be chained.
ExecuteNodeBuilder &ExecuteNodeBuilder::SetAttr(int index, AnyValue &&value) {
  if (index < 0) {
    // log error
    error_code_ = FAILED;
    return *this;
  }
  attr_indexes_to_value_[index] = std::move(value);
  return *this;
}
}  // namespace ge